{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ISubpr_SSsiM"
   },
   "source": [
    "##### Copyright 2020 The TensorFlow Authors.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "colab": {},
    "colab_type": "code",
    "id": "3jTMb1dySr3V"
   },
   "outputs": [],
   "source": [
    "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "6DWfyNThSziV"
   },
   "source": [
    "# Better performance with tf.function\n",
    "\n",
    "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/function\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/function.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/function.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/function.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "J122XQYG7W6w"
   },
   "source": [
    "In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier\n",
    "and faster), but this can come at the expense of performance and deployability.\n",
    "\n",
    "You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use `SavedModel`.\n",
    "\n",
    "This guide will help you conceptualize how `tf.function` works under the hood so you can use it effectively.\n",
    "\n",
    "The main takeaways and recommendations are:\n",
    "\n",
    "- Debug in eager mode, then decorate with `@tf.function`.\n",
    "- Don't rely on Python side effects like object mutation or list appends.\n",
    "- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SjvqpgepHJPd"
   },
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "otIdN1TS8N7S"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "I0xDjO4SHLUD"
   },
   "source": [
    "Define a helper function to demonstrate the kinds of errors you might encounter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "D25apou9IOXa"
   },
   "outputs": [],
   "source": [
    "import traceback\n",
    "import contextlib\n",
    "\n",
    "# Some helper code to demonstrate the kinds of errors you might encounter.\n",
    "@contextlib.contextmanager\n",
    "def assert_raises(error_class):\n",
    "  try:\n",
    "    yield\n",
    "  except error_class as e:\n",
    "    print('Caught expected exception \\n  {}:'.format(error_class))\n",
    "    traceback.print_exc(limit=2)\n",
    "  except Exception as e:\n",
    "    raise e\n",
    "  else:\n",
    "    raise Exception('Expected {} to be raised but no error was raised!'.format(\n",
    "        error_class))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WPSfepzTHThq"
   },
   "source": [
    "## Basics"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CNwYTIJ8r56W"
   },
   "source": [
    "### Usage\n",
    "\n",
    "A `Function` you define is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SbtT1-Wm70F2"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(2, 2), dtype=float32, numpy=\n",
       "array([[2., 2.],\n",
       "       [2., 2.]], dtype=float32)>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def add(a, b):\n",
    "  return a + b\n",
    "\n",
    "add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uP-zUelB8DbX"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(), dtype=float32, numpy=1.0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "v = tf.Variable(1.0)\n",
    "with tf.GradientTape() as tape:\n",
    "  result = add(v, 1.0)\n",
    "tape.gradient(result, v)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ocWZvqrmHnmX"
   },
   "source": [
    "You can use `Function`s inside other `Function`s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "l5qRjdbBVdU6"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(3, 2), dtype=float32, numpy=\n",
       "array([[3., 3.],\n",
       "       [3., 3.],\n",
       "       [3., 3.]], dtype=float32)>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@tf.function\n",
    "def dense_layer(x, w, b):\n",
    "  return add(tf.matmul(x, w), b)\n",
    "\n",
    "dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "piBhz7gYsHqU"
   },
   "source": [
    "`Function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zuXt4wRysI03"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Eager conv: 0.004499297589063644\n",
      "Function conv: 0.002592983655631542\n",
      "Note how there's not much difference in performance for convolutions\n"
     ]
    }
   ],
   "source": [
    "import timeit\n",
    "conv_layer = tf.keras.layers.Conv2D(100, 3)\n",
    "\n",
    "@tf.function\n",
    "def conv_fn(image):\n",
    "  return conv_layer(image)\n",
    "\n",
    "image = tf.zeros([1, 200, 200, 100])\n",
    "# warm up\n",
    "conv_layer(image); conv_fn(image)\n",
    "print(\"Eager conv:\", timeit.timeit(lambda: conv_layer(image), number=10))\n",
    "print(\"Function conv:\", timeit.timeit(lambda: conv_fn(image), number=10))\n",
    "print(\"Note how there's not much difference in performance for convolutions\")\n"
   ]
  },
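  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The convolution benchmark above covers the case of a few expensive ops. As a rough sketch of the opposite case, many small ops, the snippet below times a hypothetical helper, `many_small_ops`, that chains 200 scalar additions. The helper name, the loop count, and the number of timing repetitions are illustrative assumptions, and the measured times will vary by machine.\n",
    "\n",
    "```python\n",
    "import timeit\n",
    "import tensorflow as tf\n",
    "\n",
    "def many_small_ops(x):\n",
    "  # A Python `range` loop is unrolled during tracing,\n",
    "  # so the traced graph contains 200 separate add ops.\n",
    "  for _ in range(200):\n",
    "    x = x + 1.0\n",
    "  return x\n",
    "\n",
    "many_small_ops_fn = tf.function(many_small_ops)\n",
    "\n",
    "x = tf.constant(0.0)\n",
    "# Warm up so that tracing time is not included in the measurement.\n",
    "eager_result = many_small_ops(x)\n",
    "graph_result = many_small_ops_fn(x)\n",
    "\n",
    "print('Eager small ops:', timeit.timeit(lambda: many_small_ops(x), number=10))\n",
    "print('Function small ops:', timeit.timeit(lambda: many_small_ops_fn(x), number=10))\n",
    "```\n",
    "\n",
    "Both versions compute the same value; only the per-op dispatch overhead differs."
   ]
  },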
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uZ4Do2AV80cO"
   },
   "source": [
    "### Tracing\n",
    "\n",
    "Python's dynamic typing means that you can call functions with a variety of argument types, and Python can do something different in each scenario.\n",
    "\n",
    "Yet, to create a TensorFlow Graph, static `dtypes` and shape dimensions are required. `tf.function` bridges this gap by wrapping a Python function to create a `Function` object. Based on the given inputs, the `Function` selects the appropriate graph for the given inputs, retracing the Python function as necessary. Once you understand why and when tracing happens, it's much easier to use `tf.function` effectively!\n",
    "\n",
    "You can call a `Function` with arguments of different types to see this polymorphic behavior in action."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "kojmJrgq8U9v"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing with Tensor(\"a:0\", shape=(), dtype=int32)\n",
      "tf.Tensor(2, shape=(), dtype=int32)\n",
      "\n",
      "Tracing with Tensor(\"a:0\", shape=(), dtype=float32)\n",
      "tf.Tensor(2.2, shape=(), dtype=float32)\n",
      "\n",
      "Tracing with Tensor(\"a:0\", shape=(), dtype=string)\n",
      "tf.Tensor(b'aa', shape=(), dtype=string)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def double(a):\n",
    "  print(\"Tracing with\", a)\n",
    "  return a + a\n",
    "\n",
    "print(double(tf.constant(1)))\n",
    "print()\n",
    "print(double(tf.constant(1.1)))\n",
    "print()\n",
    "print(double(tf.constant(\"a\")))\n",
    "print()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QPfouGUQrcNb"
   },
   "source": [
    "Note that if you repeatedly call a `Function` with the same argument type, TensorFlow will reuse a previously traced graph, as the generated graph would be identical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hFccbWFRrsBp"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tf.Tensor(b'bb', shape=(), dtype=string)\n"
     ]
    }
   ],
   "source": [
    "# This doesn't print 'Tracing with ...'\n",
    "print(double(tf.constant(\"b\")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rKQ92VEWI7n8"
   },
   "source": [
    "So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:\n",
    "\n",
    "- A `tf.Graph` is the raw, language-agnostic, portable representation of your computation.\n",
    "- A `ConcreteFunction` is an eagerly-executing wrapper around a `tf.Graph`.\n",
    "- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.\n",
    "- `tf.function` wraps a Python function, returning a `Function` object.\n"
   ]
  },
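  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make this terminology concrete, here is a minimal sketch. The names `triple` and `trace_log` are hypothetical and used only for illustration. The Python list records each trace, and `get_concrete_function` retrieves a `ConcreteFunction` for a given input signature.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "trace_log = []  # Hypothetical helper: appended to only while tracing.\n",
    "\n",
    "@tf.function\n",
    "def triple(a):\n",
    "  trace_log.append(a.dtype)  # Python side effect, runs once per trace.\n",
    "  return a + a + a\n",
    "\n",
    "triple(tf.constant(1))    # Traces a graph for scalar int32 inputs.\n",
    "triple(tf.constant(2))    # Same dtype and shape: the cached graph is reused.\n",
    "triple(tf.constant(1.5))  # New dtype: a second graph is traced.\n",
    "traces_after_calls = len(trace_log)\n",
    "\n",
    "# Ask the `Function` for the `ConcreteFunction` that handles scalar float32 inputs.\n",
    "concrete = triple.get_concrete_function(\n",
    "    tf.TensorSpec(shape=[], dtype=tf.float32))\n",
    "concrete_result = concrete(tf.constant(2.0))\n",
    "\n",
    "print('Traces after three calls:', traces_after_calls)\n",
    "print('Wraps a tf.Graph:', isinstance(concrete.graph, tf.Graph))\n",
    "```\n",
    "\n",
    "The list append runs once per trace rather than once per call, which is also why relying on Python side effects inside a `Function` is discouraged."
   ]
  },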
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5f05Vr_YBUCz"
   },
   "source": [
    "## AutoGraph Transformations\n",
    "\n",
    "AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, `while`.\n",
    "\n",
    "TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "yCQTtTPTW3WF"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.282950401 0.329635143 0.596900582 0.105734468 0.23939836]\n",
      "[0.275633544 0.318192899 0.534840405 0.105342194 0.234927401]\n",
      "[0.268859029 0.307872 0.489072412 0.104954258 0.230698779]\n",
      "[0.262562841 0.298499912 0.453479916 0.104570583 0.226691335]\n",
      "[0.256691068 0.289939225 0.424755305 0.104191087 0.222886384]\n",
      "[0.251197964 0.282078862 0.400928944 0.103815697 0.219267413]\n",
      "[0.246044427 0.274828017 0.380743533 0.103444338 0.215819716]\n",
      "[0.241196781 0.268111557 0.363353 0.103076935 0.212530166]\n",
      "[0.236625835 0.261866748 0.348164022 0.102713421 0.209387019]\n",
      "[0.232306182 0.256040722 0.334746301 0.102353729 0.206379712]\n",
      "[0.228215575 0.250588566 0.322779059 0.101997793 0.203498706]\n",
      "[0.224334419 0.245471835 0.31201759 0.101645552 0.20073539]\n",
      "[0.220645383 0.240657419 0.302271456 0.101296932 0.19808197]\n",
      "[0.217133105 0.236116618 0.293389916 0.10095188 0.195531324]\n",
      "[0.213783875 0.231824398 0.285251886 0.100610331 0.193076968]\n",
      "[0.210585445 0.227758825 0.277758807 0.100272231 0.190713]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: shape=(5,), dtype=float32, numpy=\n",
       "array([0.2075268 , 0.2239006 , 0.27082953, 0.09993751, 0.18843399],\n",
       "      dtype=float32)>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Simple loop\n",
    "\n",
    "@tf.function\n",
    "def f(x):\n",
    "  while tf.reduce_sum(x) > 1:\n",
    "    tf.print(x)\n",
    "    x = tf.tanh(x)\n",
    "  return x\n",
    "\n",
    "f(tf.random.uniform([5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KxwJ8znPI0Cg"
   },
   "source": [
    "If you're curious you can inspect the code autograph generates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jlQD1ffRXJhl"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def tf__f(x):\n",
      "    do_return = False\n",
      "    retval_ = ag__.UndefinedReturnValue()\n",
      "    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:\n",
      "\n",
      "        def get_state():\n",
      "            return (x,)\n",
      "\n",
      "        def set_state(loop_vars):\n",
      "            nonlocal x\n",
      "            (x,) = loop_vars\n",
      "\n",
      "        def loop_body():\n",
      "            nonlocal x\n",
      "            ag__.converted_call(tf.print, (x,), None, fscope)\n",
      "            x = ag__.converted_call(tf.tanh, (x,), None, fscope)\n",
      "\n",
      "        def loop_test():\n",
      "            return (ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1)\n",
      "        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})\n",
      "        try:\n",
      "            do_return = True\n",
      "            retval_ = fscope.mark_return_value(x)\n",
      "        except:\n",
      "            do_return = False\n",
      "            raise\n",
      "    (do_return,)\n",
      "    return ag__.retval(retval_)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(tf.autograph.to_code(f.python_function))"
   ]
  },
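  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For comparison, the loop in `f` can also be written with `tf.while_loop` directly, without relying on AutoGraph to convert the Python `while` statement. This is a sketch, not the code that AutoGraph emits; the name `f_explicit` and the input values are illustrative assumptions.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def f_explicit(x):\n",
    "  # `cond` and `body` each receive the loop variables as arguments;\n",
    "  # `body` returns their updated values.\n",
    "  loop_result = tf.while_loop(\n",
    "      cond=lambda x: tf.reduce_sum(x) > 1,\n",
    "      body=lambda x: [tf.tanh(x)],\n",
    "      loop_vars=[x])\n",
    "  # The result has the same structure as `loop_vars`: a one-element list.\n",
    "  return loop_result[0]\n",
    "\n",
    "explicit_result = f_explicit(tf.constant([0.9, 0.8, 0.7]))\n",
    "print(explicit_result)\n",
    "```\n",
    "\n",
    "The Python `while` version is usually easier to read, which is the main reason to let AutoGraph handle the conversion."
   ]
  },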
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xgKmkrNTZSyz"
   },
   "source": [
    "### Conditionals\n",
    "\n",
    "AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.\n",
    "\n",
    "A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.\n",
    "\n",
    "`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "BOQl8PMq2Sf3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tracing for loop\n",
      "Tracing fizzbuzz branch\n",
      "Tracing fizz branch\n",
      "Tracing buzz branch\n",
      "Tracing default branch\n",
      "1\n",
      "2\n",
      "fizz\n",
      "4\n",
      "buzz\n",
      "1\n",
      "2\n",
      "fizz\n",
      "4\n",
      "buzz\n",
      "fizz\n",
      "7\n",
      "8\n",
      "fizz\n",
      "buzz\n",
      "11\n",
      "fizz\n",
      "13\n",
      "14\n",
      "fizzbuzz\n",
      "16\n",
      "17\n",
      "fizz\n",
      "19\n",
      "buzz\n"
     ]
    }
   ],
   "source": [
    "@tf.function\n",
    "def fizzbuzz(n):\n",
    "  for i in tf.range(1, n + 1):\n",
    "    print('Tracing for loop')\n",
    "    if i % 15 == 0:\n",
    "      print('Tracing fizzbuzz branch')\n",
    "      tf.print('fizzbuzz')\n",
    "    elif i % 3 == 0:\n",
    "      print('Tracing fizz branch')\n",
    "      tf.print('fizz')\n",
    "    elif i % 5 == 0:\n",
    "      print('Tracing buzz branch')\n",
    "      tf.print('buzz')\n",
    "    else:\n",
    "      print('Tracing default branch')\n",
    "      tf.print(i)\n",
    "\n",
    "fizzbuzz(tf.constant(5))\n",
    "fizzbuzz(tf.constant(20))"
   ]
  },
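  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The difference between a Python conditional and a Tensor conditional can be sketched as follows. The function `scale_and_clip` and its `double_it` flag are hypothetical names chosen for this illustration. The first `if` tests a Python bool, so it is resolved during tracing; the second tests a Tensor, so AutoGraph converts it to `tf.cond`.\n",
    "\n",
    "```python\n",
    "import tensorflow as tf\n",
    "\n",
    "@tf.function\n",
    "def scale_and_clip(x, double_it):\n",
    "  if double_it:  # Python bool: evaluated once, while tracing.\n",
    "    x = x * 2.0\n",
    "  if x > 0:      # Tensor: both branches are traced into the graph.\n",
    "    y = x\n",
    "  else:\n",
    "    y = tf.zeros_like(x)\n",
    "  return y\n",
    "\n",
    "pos_doubled = scale_and_clip(tf.constant(3.0), True)\n",
    "neg_doubled = scale_and_clip(tf.constant(-3.0), True)  # Reuses the same graph.\n",
    "pos_plain = scale_and_clip(tf.constant(3.0), False)    # New Python value: retraces.\n",
    "\n",
    "print(pos_doubled, neg_doubled, pos_plain)\n",
    "```\n",
    "\n",
    "Because `double_it` is a Python value, each distinct value gets its own trace, while the sign of `x` is handled inside a single graph."
   ]
  },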
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4rBO5AQ15HVC"
   },
   "source": [
    "See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements) for additional restrictions on AutoGraph-converted if statements."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "function.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
